"""
Phase 2: Bohn Fiscal Reaction Function Tests
==============================================
Does aging weaken the fiscal reaction function?

Core: Bohn (1998) — if primary balance responds positively to lagged debt,
fiscal policy is sustainable. We test whether demographics weaken this response.

Input:  fiscal_dominance/data/processed/fiscal_panel.csv
Output: fiscal_dominance/output/tables/phase2_*.csv
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

# Root of the project checkout (absolute WSL-style /mnt/c path).
# NOTE(review): machine-specific path — consider an env var or a path
# derived from __file__ so the script runs on other machines.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
FD_DIR = PROJECT_DIR / "fiscal_dominance"
PROCESSED_DIR = FD_DIR / "data" / "processed"  # input: fiscal_panel.csv is read from here
TABLE_DIR = FD_DIR / "output" / "tables"  # output: phase2_*.csv tables are written here

# Put multilateral/src on the import path so the shared PanelGLS estimator
# (presumably multilateral/src/model.py) can be imported. This insert must
# run before the import below.
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS


def fit_and_report(y, X, entity_ids, time_ids, feature_names, label):
    """Estimate one PanelGLS specification, print it, and tabulate the result.

    Parameters
    ----------
    y, X : array-like
        Dependent variable and regressor matrix, passed unchanged to
        ``PanelGLS.fit``.
    entity_ids, time_ids : array-like
        Panel identifiers (this script passes ``iso3`` and ``year``).
    feature_names : list of str
        Labels for the columns of ``X``, used by the printed summary and the
        coefficient table.
    label : str
        Model name shown in the banner and stored in the ``model`` column.

    Returns
    -------
    tuple
        ``(fitted_model, table)``: the fitted PanelGLS instance and its
        ``to_dataframe`` output, with the fit statistics ``model``, ``n_obs``,
        ``n_countries``, ``r_squared`` and ``rho`` appended as constant columns.
    """
    estimator = PanelGLS()
    estimator.fit(y, X, entity_ids, time_ids)

    rule = '=' * 70
    print(f"\n{rule}")
    print(f"  {label}")
    print(f"  N={estimator.n_obs:,}, {estimator.n_countries} countries, "
          f"R²={estimator.r_squared:.4f}, rho={estimator.rho:.3f}")
    print(f"{rule}")
    estimator.summary(feature_names=feature_names)

    table = estimator.to_dataframe(feature_names=feature_names)
    # Tag every coefficient row with the model's identity and fit statistics
    # so tables from different specifications can be stacked.
    fit_stats = {
        'model': label,
        'n_obs': estimator.n_obs,
        'n_countries': estimator.n_countries,
        'r_squared': estimator.r_squared,
        'rho': estimator.rho,
    }
    for column_name, column_value in fit_stats.items():
        table[column_name] = column_value
    return estimator, table


def _sig_stars(p):
    """Return conventional significance stars for p-value *p*.

    ``'***'`` for p < 0.01, ``'**'`` for p < 0.05, ``'*'`` for p < 0.1,
    otherwise ``''``. NaN compares False everywhere, so it also yields ``''``.
    """
    if p < 0.01:
        return '***'
    if p < 0.05:
        return '**'
    if p < 0.1:
        return '*'
    return ''


def main():
    """Run the Phase 2 Bohn fiscal-reaction-function tests end to end.

    Reads ``fiscal_panel.csv`` from ``PROCESSED_DIR`` and estimates a series of
    PanelGLS specifications:

    1. baseline Bohn
    2. Bohn plus demographic levels
    3. debt_lag x Z interactions
    4. debt_lag x OADR
    5. debt_lag x Z x KAOPEN, with all lower-order terms
    6. structural-balance Bohn
    7. 15-year rolling-window Bohn
    8. a revenue/expenditure decomposition

    It prints each fit and writes ``phase2_rolling_bohn.csv`` and
    ``phase2_bohn_results.csv`` to ``TABLE_DIR``, creating that directory if
    needed. A specification is skipped when its complete-case sample has fewer
    than 200 observations (fewer than 100 for a rolling window).

    Returns
    -------
    pandas.DataFrame
        Stacked coefficient tables of every estimated model, or an empty
        DataFrame if no model could be estimated.
    """
    print("=" * 70)
    print("PHASE 2: Bohn Fiscal Reaction Function Tests")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "fiscal_panel.csv")
    print(f"Loaded: {len(df):,} obs, {df['iso3'].nunique()} countries")

    # DataFrame.to_csv does not create missing parent directories.
    TABLE_DIR.mkdir(parents=True, exist_ok=True)

    all_results = []
    demo_vars = ['Z_1', 'Z_2', 'Z_3']

    # =================================================================
    # 1. Baseline Bohn Test
    #    primary_bal = beta * debt_lag + gamma * output_gap + controls + u
    # =================================================================
    dep_var = 'primary_bal_gdp'
    # Use HP-filtered output gap (181 countries) instead of WEO gap (27 countries)
    bohn_controls = ['output_gap_hp', 'govt_exp_gap']
    # Keep only controls that exist and have more than 100 non-missing values.
    bohn_controls = [c for c in bohn_controls if c in df.columns and df[c].notna().sum() > 100]
    print(f"Bohn controls: {bohn_controls}")

    core_vars = ['debt_lag'] + bohn_controls
    est = df.dropna(subset=[dep_var] + core_vars).copy()
    print(f"\nBaseline Bohn sample: {len(est):,} obs, {est['iso3'].nunique()} countries")

    if len(est) >= 200:
        m1, r1 = fit_and_report(
            est[dep_var].values, est[core_vars].values,
            est['iso3'].values, est['year'].values,
            core_vars, "Model 1: Baseline Bohn (primary_bal ~ debt_lag)"
        )
        all_results.append(r1)

        # Extract Bohn coefficient
        idx = core_vars.index('debt_lag')
        bohn_beta = m1.beta[idx]
        bohn_p = m1.pvalues[idx]
        print(f"\n  >>> BOHN COEFFICIENT: {bohn_beta:.4f} (p={bohn_p:.4f})")
        # The verdict uses only the sign of the point estimate; read it
        # together with the p-value printed above.
        print(f"  >>> Fiscal sustainability: {'YES (beta > 0)' if bohn_beta > 0 else 'NO (beta <= 0)'}")

    # =================================================================
    # 2. Bohn + Demographics (level effects)
    #    primary_bal = beta * debt_lag + gamma_Z * Z + controls + u
    # =================================================================
    vars_2 = ['debt_lag'] + demo_vars + bohn_controls
    est2 = df.dropna(subset=[dep_var] + vars_2).copy()

    if len(est2) >= 200:
        m2, r2 = fit_and_report(
            est2[dep_var].values, est2[vars_2].values,
            est2['iso3'].values, est2['year'].values,
            vars_2, "Model 2: Bohn + Demographics (Z levels)"
        )
        all_results.append(r2)

    # =================================================================
    # 3. Bohn + Demographic INTERACTIONS (key test)
    #    primary_bal = beta * debt_lag + gamma_Z * Z
    #                  + delta * (debt_lag x Z) + controls + u
    #    If delta < 0: aging weakens fiscal reaction
    # =================================================================
    df['debt_x_Z1'] = df['debt_lag'] * df['Z_1']
    df['debt_x_Z2'] = df['debt_lag'] * df['Z_2']
    df['debt_x_Z3'] = df['debt_lag'] * df['Z_3']

    interaction_vars = ['debt_x_Z1', 'debt_x_Z2', 'debt_x_Z3']
    vars_3 = ['debt_lag'] + demo_vars + interaction_vars + bohn_controls
    est3 = df.dropna(subset=[dep_var] + vars_3).copy()

    if len(est3) >= 200:
        m3, r3 = fit_and_report(
            est3[dep_var].values, est3[vars_3].values,
            est3['iso3'].values, est3['year'].values,
            vars_3, "Model 3: Bohn + debt_lag x Z Interactions (KEY TEST)"
        )
        all_results.append(r3)

        # Interpret interactions
        print("\n  >>> INTERACTION INTERPRETATION:")
        for iv in interaction_vars:
            idx = vars_3.index(iv)
            coef = m3.beta[idx]
            p = m3.pvalues[idx]
            direction = 'WEAKENS' if coef < 0 else 'STRENGTHENS'
            print(f"    {iv}: {coef:.4f} (p={p:.4f}) {_sig_stars(p)} — aging {direction} fiscal reaction")

    # =================================================================
    # 4. Bohn + OADR interaction (simpler alternative)
    # =================================================================
    if 'old_dep' in df.columns:
        df['debt_x_oadr'] = df['debt_lag'] * df['old_dep']
        vars_4 = ['debt_lag', 'old_dep', 'debt_x_oadr'] + bohn_controls
        est4 = df.dropna(subset=[dep_var] + vars_4).copy()

        if len(est4) >= 200:
            m4, r4 = fit_and_report(
                est4[dep_var].values, est4[vars_4].values,
                est4['iso3'].values, est4['year'].values,
                vars_4, "Model 4: Bohn + debt_lag x OADR"
            )
            all_results.append(r4)

            idx = vars_4.index('debt_x_oadr')
            print(f"\n  >>> debt_lag x OADR: {m4.beta[idx]:.4f} (p={m4.pvalues[idx]:.4f})")

    # =================================================================
    # 5. KAOPEN triple interaction
    #    debt_lag x Z x kaopen — does openness condition the aging effect?
    # =================================================================
    if 'kaopen' in df.columns:
        # A triple interaction is interpretable only when every lower-order
        # term is included. Without debt x kaopen and Z x kaopen, their effects
        # would be absorbed by the triple-interaction coefficients.
        df['debt_x_kaopen'] = df['debt_lag'] * df['kaopen']
        df['Z1_x_kaopen'] = df['Z_1'] * df['kaopen']
        df['Z2_x_kaopen'] = df['Z_2'] * df['kaopen']
        df['Z3_x_kaopen'] = df['Z_3'] * df['kaopen']
        df['debt_Z1_kaopen'] = df['debt_lag'] * df['Z_1'] * df['kaopen']
        df['debt_Z2_kaopen'] = df['debt_lag'] * df['Z_2'] * df['kaopen']
        df['debt_Z3_kaopen'] = df['debt_lag'] * df['Z_3'] * df['kaopen']

        kaopen_pair_vars = ['debt_x_kaopen', 'Z1_x_kaopen', 'Z2_x_kaopen', 'Z3_x_kaopen']
        triple_vars = ['debt_Z1_kaopen', 'debt_Z2_kaopen', 'debt_Z3_kaopen']
        vars_5 = (['debt_lag'] + demo_vars + interaction_vars + ['kaopen']
                  + kaopen_pair_vars + triple_vars + bohn_controls)
        est5 = df.dropna(subset=[dep_var] + vars_5).copy()

        if len(est5) >= 200:
            m5, r5 = fit_and_report(
                est5[dep_var].values, est5[vars_5].values,
                est5['iso3'].values, est5['year'].values,
                vars_5, "Model 5: Bohn + Triple Interaction (debt x Z x KAOPEN)"
            )
            all_results.append(r5)

    # =================================================================
    # 6. Structural balance version
    # =================================================================
    struct_dep = 'structural_bal_gdp'
    if struct_dep in df.columns and df[struct_dep].notna().sum() >= 200:
        vars_6 = ['debt_lag'] + demo_vars + interaction_vars + bohn_controls
        est6 = df.dropna(subset=[struct_dep] + vars_6).copy()

        if len(est6) >= 200:
            m6, r6 = fit_and_report(
                est6[struct_dep].values, est6[vars_6].values,
                est6['iso3'].values, est6['year'].values,
                vars_6, "Model 6: Structural Bohn (struct_bal ~ debt_lag x Z)"
            )
            all_results.append(r6)

    # =================================================================
    # 7. Time-varying Bohn: rolling 15-year windows
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  Time-Varying Bohn Coefficient (15-year rolling windows)")
    print("=" * 70)

    window = 15
    # The regressor list and the debt_lag position are the same in every window.
    wvars = ['debt_lag'] + bohn_controls
    debt_idx = wvars.index('debt_lag')
    rolling_rows = []
    # Window end years run from 2004 (first full window starting 1990) to 2024.
    for end_year in range(1990 + window - 1, 2025):
        start_year = end_year - window + 1
        wdf = df[(df['year'] >= start_year) & (df['year'] <= end_year)]
        wdf = wdf.dropna(subset=[dep_var] + wvars)

        if len(wdf) < 100:
            continue

        wmodel = PanelGLS()
        wmodel.fit(wdf[dep_var].values, wdf[wvars].values,
                   wdf['iso3'].values, wdf['year'].values)

        w_beta = wmodel.beta[debt_idx]
        w_p = wmodel.pvalues[debt_idx]
        rolling_rows.append({
            'window_start': start_year,
            'window_end': end_year,
            'bohn_beta': w_beta,
            'bohn_se': wmodel.se[debt_idx],
            'bohn_p': w_p,
            'n_obs': wmodel.n_obs,
            'n_countries': wmodel.n_countries,
            'r_squared': wmodel.r_squared,
        })
        print(f"  {start_year}-{end_year}: beta={w_beta:.4f} "
              f"(p={w_p:.4f}) {_sig_stars(w_p)} N={wmodel.n_obs}")

    rolling_df = pd.DataFrame(rolling_rows)
    if len(rolling_df) > 0:
        rolling_df.to_csv(TABLE_DIR / "phase2_rolling_bohn.csv", index=False)

    # =================================================================
    # 8. EXPENDITURE DECOMPOSITION
    #    Does aging raise expenditure, cut revenue, or both?
    #    Z → govt_expenditure_gdp, Z → govt_revenue_gdp
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  EXPENDITURE DECOMPOSITION: Z -> Revenue vs Expenditure")
    print("=" * 70)

    decomp_controls = ['debt_lag', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag']
    decomp_controls = [c for c in decomp_controls if c in df.columns]

    for dep_label, dep_col in [('Government Expenditure/GDP', 'govt_expenditure_gdp'),
                               ('Government Revenue/GDP', 'govt_revenue_gdp'),
                               ('Primary Balance/GDP', 'primary_bal_gdp')]:
        if dep_col not in df.columns or df[dep_col].notna().sum() < 200:
            continue
        vars_d = demo_vars + decomp_controls
        est_d = df.dropna(subset=[dep_col] + vars_d).copy()
        if len(est_d) >= 200:
            md, rd = fit_and_report(
                est_d[dep_col].values, est_d[vars_d].values,
                est_d['iso3'].values, est_d['year'].values,
                vars_d, f"Decomp: Z -> {dep_label}"
            )
            all_results.append(rd)

    # OADR version (more intuitive)
    if 'old_dep' in df.columns:
        df['old_dep_sq'] = df['old_dep'] ** 2
        oadr_controls = ['old_dep', 'old_dep_sq'] + decomp_controls
        for dep_label, dep_col in [('Government Expenditure/GDP', 'govt_expenditure_gdp'),
                                   ('Government Revenue/GDP', 'govt_revenue_gdp')]:
            if dep_col not in df.columns:
                continue
            est_d = df.dropna(subset=[dep_col] + oadr_controls).copy()
            if len(est_d) >= 200:
                md, rd = fit_and_report(
                    est_d[dep_col].values, est_d[oadr_controls].values,
                    est_d['iso3'].values, est_d['year'].values,
                    oadr_controls, f"Decomp (OADR): Z -> {dep_label}"
                )
                all_results.append(rd)

    # Change-on-change: delta_expenditure ~ delta_old_dep
    if 'delta_old_dep' in df.columns:
        # GroupBy.diff differences consecutive *rows*. Sort by year within each
        # country first, and blank out differences that span a gap in the
        # panel so every change is year-over-year. The results align back to
        # df by index, so df's own row order is left untouched.
        ordered = df.sort_values(['iso3', 'year'])
        consecutive = ordered.groupby('iso3')['year'].diff() == 1
        for src_col, chg_col in [('govt_expenditure_gdp', 'd_expenditure'),
                                 ('govt_revenue_gdp', 'd_revenue')]:
            if src_col in df.columns:
                df[chg_col] = ordered.groupby('iso3')[src_col].diff().where(consecutive)

        chg_vars = ['delta_old_dep'] + [c for c in decomp_controls if c != 'debt_lag']
        for dep_label, dep_col in [('Delta Expenditure/GDP', 'd_expenditure'),
                                   ('Delta Revenue/GDP', 'd_revenue')]:
            # Absent when the source level column is missing from the panel.
            if dep_col not in df.columns:
                continue
            est_d = df.dropna(subset=[dep_col] + chg_vars).copy()
            if len(est_d) >= 200:
                md, rd = fit_and_report(
                    est_d[dep_col].values, est_d[chg_vars].values,
                    est_d['iso3'].values, est_d['year'].values,
                    chg_vars, f"Decomp (changes): {dep_label} ~ delta_OADR"
                )
                all_results.append(rd)

    # Summary: compare Z_1 coefficients across expenditure, revenue, balance
    print(f"\n{'=' * 70}")
    print("  DECOMPOSITION SUMMARY: Z_1 effect on expenditure vs revenue")
    print("=" * 70)
    decomp_results = [r for r in all_results
                      if not r.empty and 'Decomp' in r['model'].iloc[0]]
    if decomp_results:
        decomp_df = pd.concat(decomp_results, ignore_index=True)
        z1_decomp = decomp_df[decomp_df['variable'] == 'Z_1'][
            ['model', 'coefficient', 'std_error', 'p_value', 'n_obs']
        ]
        if len(z1_decomp) > 0:
            print(z1_decomp.to_string(index=False, float_format='%.4f'))

    # =================================================================
    # Save all results
    # =================================================================
    if not all_results:
        return pd.DataFrame()

    results_df = pd.concat(all_results, ignore_index=True)
    results_path = TABLE_DIR / "phase2_bohn_results.csv"
    results_df.to_csv(results_path, index=False)
    print(f"\n{'=' * 70}")
    print(f"Saved: {results_path}")
    print(f"  {len(results_df)} coefficient rows across {results_df['model'].nunique()} models")

    # Key summary: the debt_lag (Bohn) coefficient of every model that has one.
    print(f"\n{'=' * 70}")
    print("KEY FINDINGS SUMMARY")
    print("=" * 70)
    for model_name in results_df['model'].unique():
        mdf = results_df[results_df['model'] == model_name]
        debt_row = mdf[mdf['variable'] == 'debt_lag']
        if len(debt_row) > 0:
            coef = debt_row['coefficient'].values[0]
            p = debt_row['p_value'].values[0]
            r2 = mdf['r_squared'].values[0]
            print(f"  {model_name}:")
            print(f"    Bohn beta = {coef:.4f} (p={p:.4f}), R² = {r2:.4f}")

    return results_df


if __name__ == "__main__":
    results = main()
